import os
import pyaudio
import wave
import time
from libs import DB
from pydub import AudioSegment
import noisereduce as nr
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.style as style

# Module-level DB handle; Recorder.save_audio calls db.set_audio_name() on it.
db = DB()

# Recording parameters
FORMAT = pyaudio.paInt16  # PyAudio sample format used for capture and for the WAV sample width
CHANNELS = 1  # number of channels requested from the input stream and written to the WAV header
RATE = 44100  # sampling rate passed to the stream and written to the WAV header
CHUNK = 1024  # frames per buffer / frames requested per stream.read() call
RECORD_SECONDS = 4  # length of one recording; the read count is int(RATE / CHUNK * RECORD_SECONDS)

class Recorder:
    """Record, save, denoise and delete short WAV clips stored under ``./wavs/``.

    ``record_audio`` captures raw frames into ``self.frames``; ``save_audio``
    writes those frames to disk and registers the name through the
    module-level ``db``; ``denoise_audio`` writes a ``denoise_``-prefixed
    copy; ``delete_audio`` removes both files.
    """

    def __init__(self):
        # Raw byte chunks returned by stream.read() during the last recording.
        self.frames = []
        # PyAudio instance created by the most recent record_audio() call
        # (already terminated once that call returns).
        self.audio = None

    async def record_audio(self):
        """Capture audio from the default input device into ``self.frames``.

        Reads ``int(RATE / CHUNK * RECORD_SECONDS)`` chunks of ``CHUNK``
        frames each. Nothing is written to disk here; pass the returned name
        to ``save_audio`` to persist the recording.

        NOTE(review): ``stream.read`` is a blocking call, so this coroutine
        does not yield to the event loop while recording.

        Returns:
            A timestamp-based file name of the form ``output_<epoch>.wav``.

        Raises:
            Whatever ``PyAudio.open`` / ``Stream.read`` raise; the stream and
            the PyAudio instance are released before the error propagates.
        """
        self.frames = []
        print("Recording...")
        wave_output_filename = f"output_{int(time.time())}.wav"

        # Create the PyAudio object.
        self.audio = pyaudio.PyAudio()
        try:
            # Open the input stream.
            stream = self.audio.open(
                format=FORMAT,
                channels=CHANNELS,
                rate=RATE,
                input=True,
                frames_per_buffer=CHUNK)
            try:
                # Record.
                for _ in range(int(RATE / CHUNK * RECORD_SECONDS)):
                    self.frames.append(stream.read(CHUNK))
                print("Finished recording.")
            finally:
                # Stop recording; always close the stream, even if
                # stop_stream() itself fails.
                try:
                    stream.stop_stream()
                finally:
                    stream.close()
        finally:
            # Release PortAudio resources even when open()/read() raised.
            self.audio.terminate()

        return wave_output_filename

    async def save_audio(self, filename):
        """Write ``self.frames`` to ``./wavs/<filename>`` and register it in the DB.

        Args:
            filename: Bare file name (no directory) of the WAV file to create.

        Raises:
            OSError / wave.Error: if the file cannot be written.
        """
        # Make sure the output directory exists before opening the file.
        os.makedirs("./wavs", exist_ok=True)

        # The context manager closes the file even if a write fails.
        with wave.open("./wavs/" + filename, 'wb') as wf:
            wf.setnchannels(CHANNELS)
            # Use the module-level helper: self.audio is None until
            # record_audio() has run, and is terminated afterwards.
            wf.setsampwidth(pyaudio.get_sample_size(FORMAT))
            wf.setframerate(RATE)
            wf.writeframes(b''.join(self.frames))

        await db.set_audio_name(filename)
        print(f"Recording saved as {filename}")

    async def denoise_audio(self, filename):
        """Denoise ``./wavs/<filename>`` and save it as ``./wavs/denoise_<filename>``.

        Also shows a two-panel plot of the original and the denoised samples.

        NOTE(review): ``plt.show()`` may block until the plot window is
        closed, depending on the matplotlib backend. The samples are passed
        to ``reduce_noise`` as one flat array, so multi-channel input is not
        handled per channel (recordings made here use ``CHANNELS = 1``).

        Args:
            filename: Bare file name of a WAV file that exists under ``./wavs/``.
        """
        style.use('ggplot')

        # Load the audio file.
        audio = AudioSegment.from_file("./wavs/" + filename)

        # Convert the audio to a numpy array.
        samples = np.array(audio.get_array_of_samples())

        # Reduce noise.
        reduced_noise = nr.reduce_noise(samples, sr=audio.frame_rate)

        # Plot original and reduced-noise signals.
        fig, ax = plt.subplots(2, 1, figsize=(15, 8))
        ax[0].set_title("Original signal")
        ax[0].plot(samples)
        ax[1].set_title("Reduced noise signal")
        ax[1].plot(reduced_noise)
        plt.show()
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

        # The raw bytes must match audio.sample_width; cast back to the input
        # dtype in case reduce_noise returned a different one (presumably
        # version-dependent -- the cast is a no-op when dtypes already match).
        pcm = np.asarray(reduced_noise).astype(samples.dtype, copy=False)

        # Convert the reduced-noise signal back to audio.
        reduced_audio = AudioSegment(
            pcm.tobytes(),
            frame_rate=audio.frame_rate,
            sample_width=audio.sample_width,
            channels=audio.channels
        )

        # Save the reduced audio to file.
        reduced_audio.export("./wavs/denoise_" + filename, format="wav")

    async def delete_audio(self, filename):
        """Delete ``./wavs/<filename>`` and ``./wavs/denoise_<filename>``.

        Each file is removed independently, so a failure on one (e.g. it does
        not exist) no longer prevents removal of the other. Errors are
        reported with ``print`` and never raised to the caller.

        Args:
            filename: Bare file name of the original recording.
        """
        targets = ("./wavs/" + filename, "./wavs/denoise_" + filename)
        all_removed = True
        for path in targets:
            try:
                os.remove(path)
            except OSError as e:
                all_removed = False
                print(f"Error deleting file {filename}: {e}")
        if all_removed:
            print(f"File {filename} deleted successfully")
    